-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[luci/pass] Introduce FuseGRU Pass #14252
base: master
Are you sure you want to change the base?
Conversation
This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU. ONE-DCO-1.0-Signed-off-by: Artem Balyshev <a.balyshev@samsung.com> ONE-DCO-1.0-Signed-off-by: Chunseok Lee <chunseok.lee@samsung.com>
if (_while_node == nullptr) | ||
return false; | ||
|
||
// 1 - check condition graph: only one Less operation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only one Less operation
and below condition doesn't match.
- 1/ fix comment like
Less operation should exist
- 2/ fix implementation to check only one
Less
exist
break; | ||
} | ||
|
||
// doesn't find Less node |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IMHO, this comment can be removed
if (fc_nodes.size() != 2 or mul_nodes.size() != 3 or logistic_nodes.size() != 2 or | ||
split_nodes.size() != 2 or add_nodes.size() != 6 or gather_nodes.size() != 1 or | ||
reshape_nodes.size() != 1 or sub_nodes.size() != 1 or tanh_nodes.size() != 1 or | ||
split_out_nodes.size() != 6) | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I commented in the draft code, I don't agree on only check number of Ops to check.
split_out_nodes.size() != 6) | ||
return false; | ||
|
||
// Check structure |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the algorithm of below check codes.
Please explain what is happening.
luci::CircleGRU *create_circle_gru(loco::Graph *graph); | ||
|
||
private: | ||
const GRUPatternBase *_p; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const GRUPatternBase *_p; | |
// initialized at ctor | |
const GRUPatternBase *_p; |
break; | ||
|
||
default: | ||
assert(false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's throw instead of assert
} | ||
else | ||
{ | ||
bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create<luci::CircleOutputExclude>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bias_ih_cloned = _p->_pattern_last_node->graph()->nodes()->create<luci::CircleOutputExclude>(); | |
bias_ih_cloned = graph->nodes()->create<luci::CircleOutputExclude>(); |
?
|
||
luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) | ||
{ | ||
assert(graph); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this graph
and _p->_pattern_last_node->graph()
are different, please add a not about this.
As I understand, we're looking with multiple sub graph
objects so graph
ptr can be different.
void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } | ||
|
||
protected: | ||
luci::CircleWhile *_while_node; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
luci::CircleWhile *_while_node; | |
luci::CircleWhile *_while_node = nullptr; |
and others in below too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated as suggested.
* | | ||
* [Out_1] | ||
*/ | ||
class GRUPattern1 final : public GRUPatternBase |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume 1
suffix in GRUPattern1
is to prepare more patterns.
Please leave a note about thus.
Or please use just GRUPattern
here and later we can rename this when we add more patterns.
|
||
g.init(); | ||
|
||
EXPECT_FALSE(pass.run(g.g())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
adding a note about what negative would help understanding FuseGRUTestNegGraph
.
It's a bunch of nodes connected and I can't catch what it's doing.
Co-authored-by: SaeHie Park <saehie.park@gmail.com>
@@ -551,7 +551,7 @@ luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph | |||
break; | |||
|
|||
default: | |||
assert(false); | |||
throw std::runtime_error("Unsupported data type"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw std::runtime_error("Unsupported data type"); | |
throw std::runtime_error("FuseGRU: Unsupported data type"); |
This PR introduces FuseGRUPass for fusing decomposed gru pattern into single CircleGRU.
draft : #14237
issue : #12263
ONE-DCO-1.0-Signed-off-by: Artem Balyshev a.balyshev@samsung.com
ONE-DCO-1.0-Signed-off-by: Chunseok Lee chunseok.lee@samsung.com